# Run this script in a fresh R session (e.g. via `Rscript`) rather than
# clearing the workspace with rm(list = ls()): that only deletes visible
# objects (attached packages, options and dot-objects survive) and silently
# destroys the user's work when the script is sourced into a live session.

library(readr)
library(dplyr)
library(word2vec)
library(tidytext)
library(pbmcapply)
library(stringr)

# Load the regression data and fix the level order of mp_minister_relation:
# "opposition" is the first level, so lm()'s default treatment contrasts use
# it as the reference category. Values not listed in `levels` become NA.
written_regdata <- mutate(
  read_csv("written_analysis_w2v.csv", show_col_types = FALSE),
  mp_minister_relation = factor(
    mp_minister_relation,
    levels = c("opposition", "coalition partner", "same party")
  )
)


# Fit one OLS model per word2vec sentiment outcome.
#
# Every column of `data` whose name starts with "w2v_sent" is regressed on
# mp_minister_relation, q_gender, a_gender, age, q_from_party and
# parl_period. Unless `baseline` is TRUE, topic loadings are read from
# "top_loads_k<K>.rds", left-joined onto `data`, and all but the last topic
# column are added as extra regressors. The last column is left out,
# presumably as the reference topic so that loadings summing to one are not
# collinear with the intercept (TODO confirm they sum to one).
#
# Args:
#   data: data frame holding the outcome and control columns; the topic
#     branch additionally needs `id` and `mp_id`.
#   K: number of topics; only used to pick "top_loads_k<K>.rds". Required
#     unless `baseline` is TRUE.
#   baseline: if TRUE, fit the models without topic controls.
#
# Returns:
#   A list of `lm` fits, one per outcome, named after the outcome column with
#   its first "w2v_sent_" removed.
sent_mod_est <- function(data, K, baseline = FALSE) {
  use_topics <- !isTRUE(baseline)
  if (use_topics && missing(K)) {
    stop("`K` is required unless `baseline = TRUE`.", call. = FALSE)
  }

  # Rename columns of `data` that already contain "topic" to topic01,
  # topic02, ...; seq_along() keeps this a no-op when there are none.
  topic_cols <- grep("topic", names(data))
  names(data)[topic_cols] <- sprintf("topic%02d", seq_along(topic_cols))

  w2v_models <- grep("^w2v_sent", names(data), value = TRUE)

  regressors <- c("mp_minister_relation", "q_gender", "a_gender", "age",
                  "q_from_party", "parl_period")

  if (use_topics) {
    # The join does not depend on the outcome, so do it once up front rather
    # than once per model.
    topic_data <- readRDS(str_c("top_loads_k", K, ".rds"))

    # Keep the original id in q_id and rebuild `id` as "<id>_<mp_id>", the
    # key on which the topic-loading table is joined.
    data$q_id <- data$id
    data$id <- str_c(data$q_id, "_", data$mp_id)
    data <- left_join(data, topic_data, by = "id")

    # NOTE(review): this also picks up topic columns that were already in
    # `data` (renamed above), not only those from topic_data; confirm intent.
    tops <- grep("topic", names(data), value = TRUE)
    if (length(tops) == 0L) {
      stop("No topic columns found after joining top_loads_k", K, ".rds.",
           call. = FALSE)
    }
    # head(, -1) drops the last topic column and, unlike 1:(n - 1), behaves
    # correctly when only one topic column is present.
    regressors <- c(regressors, head(tops, -1L))
  }

  w2v_coef <- lapply(w2v_models, \(x) {
    message(x)
    reg_form <- reformulate(regressors, response = x)
    lm(reg_form, data = data)
  })

  names(w2v_coef) <- sub("w2v_sent_", "", w2v_models, fixed = TRUE)
  w2v_coef
}




# Sentiment models without topic controls (baseline), then with topic-loading
# controls for K = 15, 25, 35, 55, 75 and 95 topics.
base <- sent_mod_est(written_regdata, baseline = TRUE)
k15 <- sent_mod_est(written_regdata, K = 15)
k25 <- sent_mod_est(written_regdata, K = 25)
k35 <- sent_mod_est(written_regdata, K = 35)
k55 <- sent_mod_est(written_regdata, K = 55)
k75 <- sent_mod_est(written_regdata, K = 75)
k95 <- sent_mod_est(written_regdata, K = 95)


# Print the baseline fits side by side in the console.
texreg::screenreg(base)
